{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1b4bac3b",
   "metadata": {
    "id": "1b4bac3b",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
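    "# Torch-Rechub Tutorial: Matching\n",
    "\n",
    "- Scenario: matching (candidate retrieval)\n",
    "- Models: DSSM, YouTubeDNN\n",
    "- Dataset: MovieLens-1M\n",
    "\n",
    "- This tutorial covers:\n",
    "    1. Training a DSSM matching model on the MovieLens-1M dataset\n",
    "    2. Training a YouTubeDNN matching model on the MovieLens-1M dataset\n",
    "\n",
    "- Before reading this tutorial, you should have a basic understanding of YouTubeDNN and DSSM, in particular how data flows through each model; otherwise the code may be hard to follow. Model introductions: [YouTubeDNN](https://datawhalechina.github.io/fun-rec/#/ch02/ch2.1/ch2.1.2/YoutubeDNN)    [DSSM](https://datawhalechina.github.io/fun-rec/#/ch02/ch2.1/ch2.1.2/DSSM)\n",
    "- This tutorial is a more detailed walkthrough of `examples/matching/run_ml_dssm.py` and `examples/matching/run_ml_youtube_dnn.py`; the code is largely the same as in those two files.\n",
    "- The framework is still under development and may contain bugs. If you reproduce the experiments and obtain metrics clearly higher than the current ones, you are welcome to get in touch with us."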
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "14e45fb3",
   "metadata": {
    "id": "14e45fb3",
    "outputId": "d7440817-0c09-4306-8d02-617f9e6f4f6f",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.519097700Z",
     "start_time": "2023-11-30T14:46:37.485098400Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "<torch._C.Generator at 0x146b0729510>"
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "pd.set_option('display.max_rows',500)\n",
    "pd.set_option('display.max_columns',500)\n",
    "pd.set_option('display.width',1000)\n",
    "torch.manual_seed(2022)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "815SCcTVlFg-",
   "metadata": {
    "id": "815SCcTVlFg-",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
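    "### The MovieLens dataset\n",
    "- MovieLens is a dataset released by a movie website. The [raw data](https://grouplens.org/datasets/movielens/1m/) consists of three files, users.dat, movies.dat, and ratings.dat, containing user information, movie information, and users' ratings of movies.\n",
    "\n",
    "- The full preprocessed dataset is provided as [**ml-1m.csv**](https://cowtransfer.com/s/5a3ab69ebd314e) (see examples/matching/data/ml-1m/preprocess_ml.py for the preprocessing).\n",
    "\n",
    "- The sampled file **ml-1m_sample.csv** (examples/matching/data/ml-1m/ml-1m_sample.csv) contains the first 100 samples of the full dataset and is intended for debugging. Once your code runs end to end on ml-1m_sample.csv, you can download the full dataset of 1 million samples to evaluate performance."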
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "80b073d9",
   "metadata": {
    "id": "80b073d9",
    "outputId": "a07bec90-d8a5-41a3-dd0d-a737a4dfb8e5",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.577111700Z",
     "start_time": "2023-11-30T14:46:37.494100400Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   user_id  movie_id  rating  timestamp                                   title                        genres gender  age  occupation    zip\n",
      "0        1      1193       5  978300760  One Flew Over the Cuckoo's Nest (1975)                         Drama      F    1          10  48067\n",
      "1        1       661       3  978302109        James and the Giant Peach (1996)  Animation|Children's|Musical      F    1          10  48067\n",
      "2        1       914       3  978301968                     My Fair Lady (1964)               Musical|Romance      F    1          10  48067\n",
      "3        1      3408       4  978300275                  Erin Brockovich (2000)                         Drama      F    1          10  48067\n",
      "4        1      2355       5  978824291                    Bug's Life, A (1998)   Animation|Children's|Comedy      F    1          10  48067\n"
     ]
    }
   ],
   "source": [
    "# The sample file contains only two users\n",
    "file_path = '../examples/matching/data/ml-1m/ml-1m_sample.csv'\n",
    "data = pd.read_csv(file_path)\n",
    "print(data.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "YF5Ouj7clFhB",
   "metadata": {
    "id": "YF5Ouj7clFhB",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
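    "### Training a DSSM model on the MovieLens-1M dataset\n",
    "\n",
    "[DSSM paper](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/cikm2013_DSSM_fullversion.pdf)\n",
    "\n",
    "#### Feature preprocessing\n",
    "This DSSM model uses two types of features: sparse features (SparseFeature) and sequence features (SequenceFeature).\n",
    "\n",
    "- A sparse feature takes a discrete value from a finite set (for example, a user ID; it is usually label-encoded into consecutive integers first). The model feeds it into an Embedding layer, which outputs one embedding vector.\n",
    "\n",
    "- A sequence feature is a List[SparseFeature] per sample (typically watch history, search history, and so on). By default, each element is embedded and the embeddings are averaged into a single vector. Besides averaging, other pooling methods such as concatenation or taking the extreme value can be selected via the `pooling` parameter.\n",
    "\n",
    "- The framework also supports dense features (DenseFeature), i.e. a single continuous value (such as a probability), which generally needs to be normalized. Dense features are not used in this example.\n",
    "\n",
    "All three feature types are defined in `torch_rechub/basic/features.py`."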
   ]
  },
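  {
   "cell_type": "markdown",
   "id": "sketch-embed-pool",
   "metadata": {},
   "source": [
    "To make the two feature types concrete, here is a minimal plain-Python sketch of an embedding lookup and of mean pooling over a padded sequence. It does not use the framework: `table`, `embed_sparse`, and `embed_sequence_mean` are hypothetical names, the sizes are illustrative, and skipping the padding id 0 when averaging is an assumption rather than a statement about how torch_rechub pools.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "random.seed(0)\n",
    "EMBED_DIM = 4   # illustrative; the tutorial uses 16\n",
    "VOCAB_SIZE = 6  # ids 1..5, with 0 reserved for padding\n",
    "\n",
    "# One row per id, standing in for the weight matrix of an Embedding layer.\n",
    "table = [[random.uniform(-1, 1) for _ in range(EMBED_DIM)] for _ in range(VOCAB_SIZE)]\n",
    "\n",
    "def embed_sparse(idx):\n",
    "    # Sparse feature: look up a single row.\n",
    "    return table[idx]\n",
    "\n",
    "def embed_sequence_mean(ids, pad_id=0):\n",
    "    # Sequence feature: average the rows of all non-padding ids (assumed behaviour).\n",
    "    rows = [table[i] for i in ids if i != pad_id]\n",
    "    if not rows:\n",
    "        return [0.0] * EMBED_DIM\n",
    "    return [sum(col) / len(rows) for col in zip(*rows)]\n",
    "\n",
    "user_vec = embed_sparse(3)                       # one vector for one id\n",
    "hist_vec = embed_sequence_mean([0, 0, 2, 5, 1])  # one vector for a whole history\n",
    "```\n",
    "\n",
    "Both calls return a vector of length `EMBED_DIM`, which is why a tower can concatenate sparse and sequence features before its MLP."
   ]
  },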
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "zPdETFQWlFhB",
   "metadata": {
    "id": "zPdETFQWlFhB",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.579111900Z",
     "start_time": "2023-11-30T14:46:37.512100400Z"
    }
   },
   "outputs": [],
   "source": [
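    "# Process the genres feature: use the first genre as the category label\n",
    "data[\"cate_id\"] = data[\"genres\"].apply(lambda x: x.split(\"|\")[0])\n",
    "\n",
    "# Specify the user and item column names and the sparse and dense feature lists, as expected by the framework's interface\n",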
    "user_col, item_col = \"user_id\", \"movie_id\"\n",
    "sparse_features = ['user_id', 'movie_id', 'gender', 'age', 'occupation', 'zip', \"cate_id\"]\n",
    "dense_features = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "CXAO69hclFhC",
   "metadata": {
    "id": "CXAO69hclFhC",
    "outputId": "f68453c4-b1aa-4728-c929-1ec883e4ed6f",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.580111800Z",
     "start_time": "2023-11-30T14:46:37.527100300Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   user_id  movie_id gender  age  occupation    zip    cate_id\n",
      "0        1      1193      F    1          10  48067      Drama\n",
      "1        1       661      F    1          10  48067  Animation\n",
      "2        1       914      F    1          10  48067    Musical\n",
      "3        1      3408      F    1          10  48067      Drama\n",
      "4        1      2355      F    1          10  48067  Animation\n",
      "After LabelEncoding:\n",
      "   user_id  movie_id  gender  age  occupation  zip  cate_id\n",
      "0        1        32       1    1           1    1        7\n",
      "1        1        17       1    1           1    1        3\n",
      "2        1        22       1    1           1    1        8\n",
      "3        1        91       1    1           1    1        7\n",
      "4        1        66       1    1           1    1        3\n"
     ]
    }
   ],
   "source": [
    "save_dir = '../examples/ranking/data/ml-1m/saved/'\n",
    "if not os.path.exists(save_dir):\n",
    "    os.makedirs(save_dir)\n",
    "# Apply LabelEncoding to the sparse features\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "print(data[sparse_features].head())\n",
    "feature_max_idx = {}\n",
    "for feature in sparse_features:\n",
    "    lbe = LabelEncoder()\n",
    "    data[feature] = lbe.fit_transform(data[feature]) + 1\n",
    "    feature_max_idx[feature] = data[feature].max() + 1\n",
    "    if feature == user_col:\n",
    "        user_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)}  # encoded user id -> raw user id\n",
    "    if feature == item_col:\n",
    "        item_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)}  # encoded item id -> raw item id\n",
    "np.save(save_dir+\"raw_id_maps.npy\", (user_map, item_map))  # used later during evaluation\n",
    "print('After LabelEncoding:')\n",
    "print(data[sparse_features].head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccBqvE0KlFhD",
   "metadata": {
    "id": "ccBqvE0KlFhD",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
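    "#### User tower and item tower\n",
    "DSSM consists of a user tower and an item tower. Each tower concatenates the user/item feature embeddings and passes them through an MLP (multilayer perceptron) to produce its output.\n",
    "Below we define which features belong to the item tower and the user tower:"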
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ofoMu3SKlFhD",
   "metadata": {
    "id": "ofoMu3SKlFhD",
    "outputId": "a4673ec1-74eb-43ab-edde-002a66901b64",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.580111800Z",
     "start_time": "2023-11-30T14:46:37.547098900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    user_id  gender  age  occupation  zip\n",
      "0         1       1    1           1    1\n",
      "53        2       2    2           2    2\n",
      "   movie_id  cate_id\n",
      "0        32        7\n",
      "1        17        3\n",
      "2        22        8\n",
      "3        91        7\n",
      "4        66        3\n"
     ]
    }
   ],
   "source": [
    "# Define which features belong to each tower\n",
    "user_cols = [\"user_id\", \"gender\", \"age\", \"occupation\", \"zip\"]\n",
    "item_cols = ['movie_id', \"cate_id\"]\n",
    "\n",
    "# Extract the corresponding columns from data\n",
    "user_profile = data[user_cols].drop_duplicates('user_id')\n",
    "item_profile = data[item_cols].drop_duplicates('movie_id')\n",
    "print(user_profile.head())\n",
    "print(item_profile.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "AXcl5RinlFhE",
   "metadata": {
    "id": "AXcl5RinlFhE",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
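    "#### Processing sequence features\n",
    "The sequence feature in this dataset is the watch history, generated according to timestamp by the `generate_seq_feature_match` function. Its parameters are:\n",
    "- `mode`: the training scheme for samples (0 - point-wise, 1 - pair-wise, 2 - list-wise)\n",
    "- `neg_ratio`: the number of negative samples per positive sample\n",
    "- `min_item`: the minimum number of samples per user; users with fewer are dropped and treated as cold-start users (the framework does not yet handle cold start, so they are simply discarded here)\n",
    "- `sample_method`: the negative sampling method\n",
    "\n",
    "> A note on `mode`: each model implementation in the framework only considers the training scheme proposed in its paper, so other schemes may raise errors. For example, DSSM uses point-wise training, i.e. `mode=0`; passing any other `mode` is not guaranteed to run correctly, whereas the scheme from the paper is.\n",
    "\n"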
   ]
  },
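  {
   "cell_type": "markdown",
   "id": "sketch-pointwise-neg",
   "metadata": {},
   "source": [
    "The sketch below illustrates the general idea behind point-wise sample generation: sort each user's interactions by timestamp, use the earlier items as history for each later item, and add `neg_ratio` negatives per positive. It is not the framework's implementation. `build_pointwise_samples` is a hypothetical helper, negatives are drawn uniformly from unseen items (which may differ from what `sample_method=1` does), and no test split is held out.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def build_pointwise_samples(events, n_items, neg_ratio=3, seed=0):\n",
    "    # events: list of (user_id, item_id, timestamp); item ids run from 1 to n_items\n",
    "    rng = random.Random(seed)\n",
    "    by_user = {}\n",
    "    for user, item, ts in sorted(events, key=lambda e: e[2]):\n",
    "        by_user.setdefault(user, []).append(item)\n",
    "\n",
    "    samples = []  # tuples of (user, target_item, history, label)\n",
    "    for user, items in by_user.items():\n",
    "        seen = set(items)\n",
    "        candidates = [i for i in range(1, n_items + 1) if i not in seen]\n",
    "        for pos in range(1, len(items)):\n",
    "            hist = items[:pos]                       # everything watched before the target\n",
    "            samples.append((user, items[pos], hist, 1))\n",
    "            for _ in range(neg_ratio):               # uniform negatives (assumption)\n",
    "                samples.append((user, rng.choice(candidates), hist, 0))\n",
    "    return samples\n",
    "\n",
    "events = [(1, 10, 3), (1, 11, 1), (1, 12, 2), (2, 10, 5), (2, 13, 4)]\n",
    "samples = build_pointwise_samples(events, n_items=20)\n",
    "```\n",
    "\n",
    "With `neg_ratio=3`, every positive row is followed by three rows with label 0 that share the same history."
   ]
  },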
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "XZ4jzoiulFhE",
   "metadata": {
    "id": "XZ4jzoiulFhE",
    "outputId": "ead082af-3809-4682-a7c2-769894bc2727",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.627625Z",
     "start_time": "2023-11-30T14:46:37.556100100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocess data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generate sequence features: 100%|██████████| 2/2 [00:00<00:00, 999.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_train: 384, n_test: 2\n",
      "0 cold start user droped \n",
      "   user_id  movie_id                                      hist_movie_id  histlen_movie_id  label\n",
      "0        2        48  [35, 37, 43, 32, 78, 36, 34, 92, 3, 79, 86, 82...                40      0\n",
      "1        2        49  [35, 37, 43, 32, 78, 36, 34, 92, 3, 79, 86, 82...                33      1\n",
      "2        1         1  [87, 51, 25, 41, 65, 53, 91, 34, 74, 32, 5, 18...                34      0\n",
      "3        1        19        [87, 51, 25, 41, 65, 53, 91, 34, 74, 32, 5]                11      0\n",
      "4        2        65  [35, 37, 43, 32, 78, 36, 34, 92, 3, 79, 86, 82...                30      0\n",
      "{'user_id': array([2, 2, 1]), 'movie_id': array([48, 49,  1]), 'hist_movie_id': array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 35, 37, 43, 32, 78, 36,\n",
      "        34, 92,  3, 79, 86, 82, 44, 56, 40, 21, 30, 93, 80, 81, 39, 61,\n",
      "        60, 62, 88, 15, 38, 45, 31, 64, 84, 58, 76, 49, 89, 16, 52, 83,\n",
      "         7, 75],\n",
      "       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
      "         0, 35, 37, 43, 32, 78, 36, 34, 92,  3, 79, 86, 82, 44, 56, 40,\n",
      "        21, 30, 93, 80, 81, 39, 61, 60, 62, 88, 15, 38, 45, 31, 64, 84,\n",
      "        58, 76],\n",
      "       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
      "        87, 51, 25, 41, 65, 53, 91, 34, 74, 32,  5, 18, 23, 14, 70, 55,\n",
      "        58, 82, 24, 28, 56, 57,  4, 26, 29, 22, 42, 73, 71, 38, 17, 77,\n",
      "        10, 85]]), 'histlen_movie_id': array([40, 33, 34]), 'label': array([0, 1, 0]), 'gender': array([2, 2, 1]), 'age': array([2, 2, 1]), 'occupation': array([2, 2, 1]), 'zip': array([2, 2, 1]), 'cate_id': array([2, 1, 3])}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from torch_rechub.utils.match import generate_seq_feature_match, gen_model_input\n",
    "df_train, df_test = generate_seq_feature_match(data,\n",
    "                                               user_col,\n",
    "                                               item_col,\n",
    "                                               time_col=\"timestamp\",\n",
    "                                               item_attribute_cols=[],\n",
    "                                               sample_method=1,\n",
    "                                               mode=0,\n",
    "                                               neg_ratio=3,\n",
    "                                               min_item=0)\n",
    "print(df_train.head())\n",
    "x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=50)\n",
    "y_train = x_train[\"label\"]\n",
    "x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=50)\n",
    "y_test = x_test[\"label\"]\n",
    "print({k: v[:3] for k, v in x_train.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "scu7A3gflFhF",
   "metadata": {
    "id": "scu7A3gflFhF",
    "outputId": "cddb3abb-f777-4bb1-a964-b3aa14208fa1",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.726138400Z",
     "start_time": "2023-11-30T14:46:37.591621400Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<SparseFeature user_id with Embedding shape (3, 16)>, <SparseFeature gender with Embedding shape (3, 16)>, <SparseFeature age with Embedding shape (3, 16)>, <SparseFeature occupation with Embedding shape (3, 16)>, <SparseFeature zip with Embedding shape (3, 16)>, <SequenceFeature hist_movie_id with Embedding shape (94, 16)>]\n",
      "[<SparseFeature movie_id with Embedding shape (94, 16)>, <SparseFeature cate_id with Embedding shape (11, 16)>]\n"
     ]
    }
   ],
   "source": [
    "# Define the feature types\n",
    "\n",
    "from torch_rechub.basic.features import SparseFeature, SequenceFeature\n",
    "user_features = [\n",
    "    SparseFeature(feature_name, vocab_size=feature_max_idx[feature_name], embed_dim=16) for feature_name in user_cols\n",
    "]\n",
    "user_features += [\n",
    "    SequenceFeature(\"hist_movie_id\",\n",
    "                    vocab_size=feature_max_idx[\"movie_id\"],\n",
    "                    embed_dim=16,\n",
    "                    pooling=\"mean\",\n",
    "                    shared_with=\"movie_id\")\n",
    "]\n",
    "\n",
    "item_features = [\n",
    "    SparseFeature(feature_name, vocab_size=feature_max_idx[feature_name], embed_dim=16) for feature_name in item_cols\n",
    "]\n",
    "print(user_features)\n",
    "print(item_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "IqShxlo3lFhF",
   "metadata": {
    "id": "IqShxlo3lFhF",
    "outputId": "adb85bf8-303c-4261-b11f-0742cc6910fb",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:37.743140400Z",
     "start_time": "2023-11-30T14:46:37.616622400Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'movie_id': array([32, 17, 22]), 'cate_id': array([7, 3, 8])}\n",
      "{'user_id': 1, 'movie_id': 2, 'hist_movie_id': array([25, 41, 65, 53, 91, 34, 74, 32,  5, 18, 23, 14, 70, 55, 58, 82, 24,\n",
      "       28, 56, 57,  4, 26, 29, 22, 42, 73, 71, 38, 17, 77, 10, 85, 72, 64,\n",
      "       27, 33, 12, 67, 47,  9, 13,  1, 69, 19, 11, 20, 66, 63, 48, 54]), 'histlen_movie_id': 52, 'label': 1, 'gender': 1, 'age': 1, 'occupation': 1, 'zip': 1, 'cate_id': 3}\n"
     ]
    }
   ],
   "source": [
    "# Convert the DataFrames to dicts\n",
    "from torch_rechub.utils.data import df_to_dict\n",
    "all_item = df_to_dict(item_profile)\n",
    "test_user = x_test\n",
    "print({k: v[:3] for k, v in all_item.items()})\n",
    "print({k: v[0] for k, v in test_user.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc52e9a5",
   "metadata": {
    "id": "fc52e9a5",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
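    "### Training the model\n",
    "\n",
    "- From the `x_train` dict and `y_train` built earlier, generate the training DataLoader (train_dl) and the test DataLoaders (test_dl, item_dl).\n",
    "\n",
    "- Define a two-tower DSSM model. `user_features` specifies the features of the user tower, and `user_params` specifies the layer dimensions and activation function of the user tower's MLP. (Note: in this example the choice of activation function strongly affects the final result, so do not change the activation parameter while debugging.)\n",
    "- Define a MatchTrainer to train the model."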
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "4fb791e6",
   "metadata": {
    "id": "4fb791e6",
    "outputId": "557de275-42df-451d-94b1-15990999f80d",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:53.548094900Z",
     "start_time": "2023-11-30T14:46:37.621622700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train: 100%|██████████| 2/2 [00:15<00:00,  7.96s/it]\n"
     ]
    }
   ],
   "source": [
    "from torch_rechub.models.matching import DSSM\n",
    "from torch_rechub.trainers import MatchTrainer\n",
    "from torch_rechub.utils.data import MatchDataGenerator\n",
    "# Build the DataLoaders from the data prepared above\n",
    "dg = MatchDataGenerator(x=x_train, y=y_train)\n",
    "train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)\n",
    "\n",
    "# Define the model\n",
    "model = DSSM(user_features,\n",
    "             item_features,\n",
    "             temperature=0.02,\n",
    "             user_params={\n",
    "                 \"dims\": [256, 128, 64],\n",
    "                 \"activation\": 'prelu',  # important!!\n",
    "             },\n",
    "             item_params={\n",
    "                 \"dims\": [256, 128, 64],\n",
    "                 \"activation\": 'prelu',  # important!!\n",
    "             })\n",
    "\n",
    "# Model trainer\n",
    "trainer = MatchTrainer(model,\n",
    "                       mode=0,  # same as the mode used above; the two must match\n",
    "                       optimizer_params={\n",
    "                           \"lr\": 1e-4,\n",
    "                           \"weight_decay\": 1e-6\n",
    "                       },\n",
    "                       n_epoch=1,\n",
    "                       device='cpu',\n",
    "                       model_path=save_dir)\n",
    "\n",
    "# Start training\n",
    "trainer.fit(train_dl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "or-2bVFslFhG",
   "metadata": {
    "id": "or-2bVFslFhG",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
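    "### Embedding-based retrieval and evaluation\n",
    "- Use the trainer to obtain the embedding of every user in the test set and the embeddings of all items in the dataset\n",
    "- Build an index over the item embeddings with Annoy, then for each user vector retrieve K items via ANN (Approximate Nearest Neighbors) search\n",
    "- Inspect the top-k evaluation metrics, typically recall, precision, and hit\n"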
   ]
  },
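  {
   "cell_type": "markdown",
   "id": "sketch-topk-eval",
   "metadata": {},
   "source": [
    "The retrieval and scoring steps can be sketched without Annoy by swapping in an exact (brute-force) cosine search. Everything here is illustrative: `cosine`, `topk_items`, and `recall_and_hit` are hypothetical helpers, the vectors are made up, and the metric definitions are the common ones, which may differ in detail from `topk_metrics`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine(u, v):\n",
    "    dot = sum(a * b for a, b in zip(u, v))\n",
    "    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))\n",
    "    return dot / norm if norm else 0.0\n",
    "\n",
    "def topk_items(user_vec, item_vecs, k):\n",
    "    # item_vecs maps item_id -> vector; exact search stands in for the ANN index\n",
    "    ranked = sorted(item_vecs, key=lambda i: cosine(user_vec, item_vecs[i]), reverse=True)\n",
    "    return ranked[:k]\n",
    "\n",
    "def recall_and_hit(pred, truth):\n",
    "    hits = len(set(pred) & set(truth))\n",
    "    return hits / len(truth), float(hits > 0)\n",
    "\n",
    "item_vecs = {1: [1.0, 0.0], 2: [0.0, 1.0], 3: [0.7, 0.7], 4: [-1.0, 0.0]}\n",
    "pred = topk_items([1.0, 0.1], item_vecs, k=2)\n",
    "recall, hit = recall_and_hit(pred, truth=[3, 4])\n",
    "```\n",
    "\n",
    "Exact search is fine for a handful of items; an ANN index trades a little accuracy for speed when the item set is large."
   ]
  },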
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "3866e820",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:46:53.566095Z",
     "start_time": "2023-11-30T14:46:53.553095300Z"
    }
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch_rechub.utils.match import Annoy\n",
    "from torch_rechub.basic.metric import topk_metrics\n",
    "\n",
    "def match_evaluation(user_embedding, item_embedding, test_user, all_item, user_col='user_id', item_col='movie_id',\n",
    "                     raw_id_maps=\"./raw_id_maps.npy\", topk=10):\n",
    "    print(\"evaluate embedding matching on test data\")\n",
    "    annoy = Annoy(n_trees=10)\n",
    "    annoy.fit(item_embedding)\n",
    "\n",
    "    #for each user of test dataset, get ann search topk result\n",
    "    print(\"matching for topk\")\n",
    "    user_map, item_map = np.load(raw_id_maps, allow_pickle=True)\n",
    "    match_res = collections.defaultdict(dict)  # user id -> predicted item ids\n",
    "    for user_id, user_emb in zip(test_user[user_col], user_embedding):\n",
    "        items_idx, items_scores = annoy.query(v=user_emb, n=topk)  #the index of topk match items\n",
    "        match_res[user_map[user_id]] = np.vectorize(item_map.get)(all_item[item_col][items_idx])\n",
    "\n",
    "    #get ground truth\n",
    "    print(\"generate ground truth\")\n",
    "\n",
    "    data = pd.DataFrame({user_col: test_user[user_col], item_col: test_user[item_col]})\n",
    "    data[user_col] = data[user_col].map(user_map)\n",
    "    data[item_col] = data[item_col].map(item_map)\n",
    "    user_pos_item = data.groupby(user_col).agg(list).reset_index()\n",
    "    ground_truth = dict(zip(user_pos_item[user_col], user_pos_item[item_col]))  # user id -> ground truth\n",
    "\n",
    "    print(\"compute topk metrics\")\n",
    "    out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[topk])\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "CKfpMh0vlFhG",
   "metadata": {
    "id": "CKfpMh0vlFhG",
    "outputId": "a722cd9d-1909-423b-a7d1-52e84a6c10e3",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:47:02.176867900Z",
     "start_time": "2023-11-30T14:46:53.570095100Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "user inference: 100%|██████████| 1/1 [00:04<00:00,  4.72s/it]\n",
      "item inference: 100%|██████████| 1/1 [00:03<00:00,  3.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluate embedding matching on test data\n",
      "matching for topk\n",
      "generate ground truth\n",
      "compute topk metrics\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": "defaultdict(list,\n            {'NDCG': ['NDCG@10: 0.0'],\n             'MRR': ['MRR@10: 0.0'],\n             'Recall': ['Recall@10: 0.0'],\n             'Hit': ['Hit@10: 0.0'],\n             'Precision': ['Precision@10: 0.0']})"
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "user_embedding = trainer.inference_embedding(model=model, mode=\"user\", data_loader=test_dl, model_path=save_dir)\n",
    "item_embedding = trainer.inference_embedding(model=model, mode=\"item\", data_loader=item_dl, model_path=save_dir)\n",
    "match_evaluation(user_embedding, item_embedding, test_user, all_item, topk=10, raw_id_maps=save_dir+\"raw_id_maps.npy\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "yA0ORyDTlFhH",
   "metadata": {
    "id": "yA0ORyDTlFhH",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
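    "### Training a YouTubeDNN model on the MovieLens-1M dataset\n",
    "\n",
    "[YouTubeDNN paper](https://dl.acm.org/doi/pdf/10.1145/2959100.2959190)\n",
    "\n",
    "Although YouTubeDNN is called a single-tower model, it is built on the same idea as a two-tower model, so the model and the rest of the pipeline are very similar.\n",
    "The YouTubeDNN code is given below. Lines that differ from the DSSM code are marked with numbered comments (`# [0]`, `# [1]`, ...), so you can see how the two differ."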
   ]
  },
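  {
   "cell_type": "markdown",
   "id": "sketch-listwise-loss",
   "metadata": {},
   "source": [
    "In list-wise training each sample holds one positive followed by its negatives, and the model is asked to pick the positive out of that list. The sketch below shows one common way to score this: a softmax cross-entropy over temperature-scaled cosine similarities, with the target class fixed at index 0. `listwise_loss` is a hypothetical helper, and the use of cosine similarity is an assumption rather than a description of the framework's internals.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def listwise_loss(user_vec, cand_vecs, temperature=0.02, label=0):\n",
    "    # cand_vecs[0] is the positive item; the rest are sampled negatives\n",
    "    def cos(u, v):\n",
    "        dot = sum(a * b for a, b in zip(u, v))\n",
    "        norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))\n",
    "        return dot / norm if norm else 0.0\n",
    "\n",
    "    logits = [cos(user_vec, v) / temperature for v in cand_vecs]\n",
    "    m = max(logits)  # subtract the max for a numerically stable log-sum-exp\n",
    "    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))\n",
    "    return log_sum - logits[label]  # cross-entropy with the target at index `label`\n",
    "\n",
    "user = [1.0, 0.0]\n",
    "loss_good = listwise_loss(user, [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])  # positive is closest\n",
    "loss_bad = listwise_loss(user, [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])   # positive is farthest\n",
    "```\n",
    "\n",
    "Because the positive always sits at index 0, the label array for such samples can simply be all zeros."
   ]
  },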
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "oM0gkV2KlFhH",
   "metadata": {
    "id": "oM0gkV2KlFhH",
    "outputId": "d5077600-b9ca-4cf4-a041-7d1e04b799da",
    "pycharm": {
     "name": "#%%\n"
    },
    "ExecuteTime": {
     "end_time": "2023-11-30T14:47:44.765085100Z",
     "start_time": "2023-11-30T14:47:33.160242700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocess data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generate sequence features: 100%|██████████| 2/2 [00:00<00:00, 1997.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_train: 96, n_test: 2\n",
      "0 cold start user droped \n",
      "epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train: 100%|██████████| 1/1 [00:04<00:00,  4.02s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference embedding\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "user inference: 100%|██████████| 1/1 [00:03<00:00,  3.91s/it]\n",
      "item inference: 100%|██████████| 1/1 [00:03<00:00,  3.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluate embedding matching on test data\n",
      "matching for topk\n",
      "generate ground truth\n",
      "compute topk metrics\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": "defaultdict(list,\n            {'NDCG': ['NDCG@10: 0.0'],\n             'MRR': ['MRR@10: 0.0'],\n             'Recall': ['Recall@10: 0.0'],\n             'Hit': ['Hit@10: 0.0'],\n             'Precision': ['Precision@10: 0.0']})"
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from torch_rechub.models.matching import YoutubeDNN\n",
    "from torch_rechub.trainers import MatchTrainer\n",
    "from torch_rechub.basic.features import SparseFeature, SequenceFeature\n",
    "from torch_rechub.utils.match import generate_seq_feature_match, gen_model_input\n",
    "from torch_rechub.utils.data import df_to_dict, MatchDataGenerator\n",
    "\n",
    "\n",
    "torch.manual_seed(2022)\n",
    "\n",
    "data = pd.read_csv(file_path)\n",
    "data[\"cate_id\"] = data[\"genres\"].apply(lambda x: x.split(\"|\")[0])\n",
    "sparse_features = ['user_id', 'movie_id', 'gender', 'age', 'occupation', 'zip', \"cate_id\"]\n",
    "user_col, item_col = \"user_id\", \"movie_id\"\n",
    "\n",
    "feature_max_idx = {}\n",
    "for feature in sparse_features:\n",
    "    lbe = LabelEncoder()\n",
    "    data[feature] = lbe.fit_transform(data[feature]) + 1\n",
    "    feature_max_idx[feature] = data[feature].max() + 1\n",
    "    if feature == user_col:\n",
    "        user_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)}  #encode user id: raw user id\n",
    "    if feature == item_col:\n",
    "        item_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)}  #encode item id: raw item id\n",
    "np.save(save_dir+\"raw_id_maps.npy\", (user_map, item_map))\n",
    "user_cols = [\"user_id\", \"gender\", \"age\", \"occupation\", \"zip\"]\n",
    "item_cols = [\"movie_id\", \"cate_id\"]\n",
    "user_profile = data[user_cols].drop_duplicates('user_id')\n",
    "item_profile = data[item_cols].drop_duplicates('movie_id')\n",
    "\n",
    "\n",
    "# Note: mode=2 generates list-wise negative samples, stored in the last column \"neg_items\"\n",
    "df_train, df_test = generate_seq_feature_match(data,\n",
    "                                               user_col,\n",
    "                                               item_col,\n",
    "                                               time_col=\"timestamp\",\n",
    "                                               item_attribute_cols=[],\n",
    "                                               sample_method=1,\n",
    "                                               mode=2,  # [0] Switch the training scheme to list-wise\n",
    "                                               neg_ratio=3,\n",
    "                                               min_item=0)\n",
    "x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=50)\n",
    "y_train = np.array([0] * df_train.shape[0])  # [1] All training labels are 0: each sample is laid out as (pos, neg1, neg2, ...) and treated as a multi-class task, so the positive is always at index 0\n",
    "x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=50)\n",
    "\n",
    "user_cols = ['user_id', 'gender', 'age', 'occupation', 'zip']\n",
    "\n",
    "user_features = [SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16) for name in user_cols]\n",
    "user_features += [\n",
    "    SequenceFeature(\"hist_movie_id\",\n",
    "                    vocab_size=feature_max_idx[\"movie_id\"],\n",
    "                    embed_dim=16,\n",
    "                    pooling=\"mean\",\n",
    "                    shared_with=\"movie_id\")\n",
    "]\n",
    "\n",
    "item_features = [SparseFeature('movie_id', vocab_size=feature_max_idx['movie_id'], embed_dim=16)]  # [2] The item tower has a single feature, the item ID (movie_id)\n",
    "neg_item_feature = [\n",
    "    SequenceFeature('neg_items',\n",
    "                    vocab_size=feature_max_idx['movie_id'],\n",
    "                    embed_dim=16,\n",
    "                    pooling=\"concat\",\n",
    "                    shared_with=\"movie_id\")\n",
    "]  # [3] An extra neg item feature is passed to the model and used in the item tower\n",
    "\n",
    "all_item = df_to_dict(item_profile)\n",
    "test_user = x_test\n",
    "\n",
    "dg = MatchDataGenerator(x=x_train, y=y_train)\n",
    "train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)\n",
    "\n",
    "model = YoutubeDNN(user_features, item_features, neg_item_feature, user_params={\"dims\": [128, 64, 16]}, temperature=0.02)  # [4] The last MLP layer must have the same dimension as the item embedding\n",
    "\n",
    "# mode=2 means list-wise learning; it must match the mode used for sample generation above\n",
    "trainer = MatchTrainer(model,\n",
    "                       mode=2,\n",
    "                       optimizer_params={\n",
    "                           \"lr\": 1e-4,\n",
    "                           \"weight_decay\": 1e-6\n",
    "                       },\n",
    "                       n_epoch=1, #5\n",
    "                       device='cpu',\n",
    "                       model_path=save_dir)\n",
    "\n",
    "trainer.fit(train_dl)\n",
    "\n",
    "print(\"inference embedding\")\n",
    "user_embedding = trainer.inference_embedding(model=model, mode=\"user\", data_loader=test_dl, model_path=save_dir)\n",
    "item_embedding = trainer.inference_embedding(model=model, mode=\"item\", data_loader=item_dl, model_path=save_dir)\n",
    "match_evaluation(user_embedding, item_embedding, test_user, all_item, topk=10, raw_id_maps=save_dir+\"raw_id_maps.npy\")\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Torch-Rechub Tutorial：Matching.ipynb",
   "provenance": []
  },
  "interpreter": {
   "hash": "2f0699014af7f4c9080a159fe6ab9f0087a283cb8192b31d41a414a088fd29ff"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
